import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.metrics import precision_score, recall_score
import matplotlib as mpl
import matplotlib.pyplot as plt

# Load the MNIST table: one row per image, with a 'class' label column.
mnist = pd.read_csv('./datasets/mnist_784.csv')

# Feature matrix: the first 784 columns (28 * 28 pixels per image).
# Slicing before converting keeps the label column out of the array, so it
# cannot influence the resulting dtype.
# NOTE(review): assumes the pixel columns come first in the CSV - confirm.
X = mnist.iloc[:, :784].to_numpy()
y = mnist['class']

# First 60000 rows are used for training, the remainder for testing.
TRAIN_SIZE = 60000
X_train, X_test = X[:TRAIN_SIZE], X[TRAIN_SIZE:]
y_train, y_test = y[:TRAIN_SIZE], y[TRAIN_SIZE:]

# Linear classifier trained with stochastic gradient descent.
# NOTE(review): max_iter=5 is a very small epoch budget; scikit-learn will
# likely emit a ConvergenceWarning - raise it if accuracy matters.
sgd = SGDClassifier(random_state=60, max_iter=5, tol=1e-3, verbose=0)
sgd.fit(X_train, y_train)

# Mean 3-fold cross-validation accuracy, computed on the training split only
# so the held-out test rows never take part in model validation.
# cross_val_score fits clones of the estimator; `sgd` itself is not refitted.
score = np.mean(cross_val_score(sgd, X_train, y_train, cv=3))

y_pred_train = sgd.predict(X_train)
y_pred_test = sgd.predict(X_test)

# Class-frequency-weighted averages of the per-class precision and recall.
precision_train = precision_score(y_train, y_pred_train, average='weighted')
recall_train = recall_score(y_train, y_pred_train, average='weighted')
precision_test = precision_score(y_test, y_pred_test, average='weighted')
recall_test = recall_score(y_test, y_pred_test, average='weighted')

print(f'Score: {score}')
print(f'Precision. Train: {precision_train}, Test: {precision_test}')
print(f'Recall. Train: {recall_train}, Test: {recall_test}')

# Show the first NUM_PREDICTIONS test digits side by side and report, for
# each one, whether the classifier predicted its label correctly.
NUM_PREDICTIONS = 15
fig, axs = plt.subplots(1, NUM_PREDICTIONS, figsize=(10, 10))

# Work on local slices; X_test and y_test themselves are left unchanged
# (X_test stays 2-D, one row per image).
samples = np.asarray(X_test)[:NUM_PREDICTIONS]
labels = np.asarray(y_test)[:NUM_PREDICTIONS]

# atleast_1d: plt.subplots returns a bare Axes (not an array) for one column.
# zip stops early if there are fewer test rows than subplots.
for ax, pixels in zip(np.atleast_1d(axs), samples):
    # Each row holds 784 pixel values; restore the 28x28 image grid.
    ax.imshow(pixels.reshape(28, 28), cmap=mpl.cm.binary,
              interpolation='nearest')
    ax.axis("off")

# One batched predict call instead of one call per sample.
predicted = sgd.predict(samples)
out = ''.join(
    f"{i}: {'success' if hit else 'fail'}; "
    for i, hit in enumerate(labels == predicted)
)
print(out)

# Display the figure; without this nothing is shown when run as a script.
plt.show()
